#!/usr/bin/env bash
#
# Launch T5 fine-tuning on SQuAD via torch.distributed.launch on one node.
#
# Usage: <this-script> COMP_TYPE
#   COMP_TYPE  forwarded verbatim as --comp-type and used to name the run
#              (SAVE_NAME="<COMP_TYPE>-random").
#
# Outputs:
#   checkpoints -> /data/checkpoints/results/<DATASET>-<SAVE_NAME>
#   log         -> /data/checkpoints/logs/t5_squad/<DATASET>-<SAVE_NAME>.log
#                  (combined stdout+stderr, also echoed to the terminal)
#
# Lines marked "to change" hold placeholder paths that must be edited
# before running.

# Fail on errors and unset variables. pipefail makes the final
# `train | tee` pipeline report the trainer's failure instead of tee's
# (normally successful) exit status.
set -euo pipefail

# The first positional argument is mandatory. Without it the original
# command line lost --comp-type's value entirely.
if (( $# < 1 )) || [[ -z "$1" ]]; then
  printf 'Usage: %s COMP_TYPE\n' "${0##*/}" >&2
  exit 2
fi
COMP_TYPE=$1

# ---- torch.distributed.launch topology (single node, 8 processes) ----
MASTER_ADDR=localhost
MASTER_PORT=12345
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=8

# Arrays keep every argument a single word, whatever it contains.
DISTRIBUTED_ARGS=(
  --nproc_per_node "$GPUS_PER_NODE"
  --nnodes "$NNODES"
  --node_rank "$NODE_RANK"
  --master_addr "$MASTER_ADDR"
  --master_port "$MASTER_PORT"
)

# ---- run identity and output locations ----
BASE_PATH="yourpath" # to change
VERSION="3b"
DATASET="SQuAD"
SAVE_NAME="${COMP_TYPE}-random" # to change

SAVE_DIR="/data/checkpoints/results/${DATASET}-${SAVE_NAME}"
LOG_DIR="/data/checkpoints/logs/t5_squad"
# Create the log directory too, or tee cannot open the log file and the
# log is silently dropped.
mkdir -p -- "$SAVE_DIR" "$LOG_DIR"

# ---- arguments for the training script ----
OPTS=(
  --dataset "$DATASET"
  --base-path "$BASE_PATH"
  --model-config "yourpath/t5-${VERSION}" # to change
  --batch-size 4
  --train-iters 6000
  --save-iters 1000
  --max-encoder-length 512
  --max-decoder-length 32
  --save "$SAVE_DIR"
  --save-name "${DATASET}-${SAVE_NAME}"
  --lr 0.00001
  --inspect-iters 100
  --warmup-iters 3000
  --lr-decay-style constant
  --weight-decay 1e-2
  --clip-grad 10.0
  --loss-scale 1048576
  --pet True
  --comp-type "$COMP_TYPE"
  --pet-init-type random
  --recover False
  --distill False
  --quant-ckpt-path "yourpath/t5-3b-test-q/checkpoint.pt"
  --moe-ckpt-path "yourpath/param_split/"
  --pr-ckpt-path "yourpath/t5-3b-test-d-p/checkpoint.pt"
  --spr-ckpt-path "yourpath/t5-3b-test-d-sp/checkpoints/ckpt-100000.pt"
  --model-ckpt-path "yourpath/t5-${VERSION}/pytorch_model.pt"
  --mix-ckpt-path "yourpath/t5-3b-test-q-d-p/checkpoint.pt"
  --inherit-ckpt-path "/data/checkpoints/results/${DATASET}-${DATASET}/${DATASET}-${DATASET}-3.pt" # to change
  --quant-config-path "${BASE_PATH}/examples/t5/quant_config.json"
  --pr-config-path "${BASE_PATH}/examples/t5/prune_config.json"
  --spr-config-path "${BASE_PATH}/examples/t5/sprune_config.json"
  --mix-layer-ckpt-path "yourpath/param_split/"
)

CMD=(
  python3 -m torch.distributed.launch "${DISTRIBUTED_ARGS[@]}"
  "${BASE_PATH}/examples/t5/finetune_t5_squad.py" "${OPTS[@]}"
)
# Inference variant: use infer_t5_squad.py instead of finetune_t5_squad.py:
# CMD=(python3 -m torch.distributed.launch "${DISTRIBUTED_ARGS[@]}" "${BASE_PATH}/examples/t5/infer_t5_squad.py" "${OPTS[@]}")

# Print the full command line (space-joined) before running it.
printf '%s\n' "${CMD[*]}"

"${CMD[@]}" 2>&1 | tee "${LOG_DIR}/${DATASET}-${SAVE_NAME}.log"
